#include "user_rights.h"
#include <LM.h>
#include <WtsApi32.h>
#include "debug.h"
#include "error.h"
#include "logging.h"
#include "scoped_any.h"
#include "scope_guard.h"
#include "system_info.h"
#include "vistautil.h"

// Reports whether |token| is a member of the BUILTIN\Administrators alias.
// The membership test itself is delegated to BelongsToGroup.
bool UserRights::TokenIsAdmin(HANDLE token)
{
	const int admins_alias = DOMAIN_ALIAS_RID_ADMINS;
	return BelongsToGroup(token, admins_alias);
}
bool UserRights::UserIsAdmin()
{
	return BelongsToGroup(NULL, DOMAIN_ALIAS_RID_ADMINS);
}
bool UserRights::UserIsUser()
{
	return BelongsToGroup(NULL, DOMAIN_ALIAS_RID_USERS);
}

bool UserRights::UserIsPowerUser()
{
	return BelongsToGroup(NULL, DOMAIN_ALIAS_RID_POWER_USERS);
}

bool UserRights::UserIsGuest()
{
	return BelongsToGroup(NULL, DOMAIN_ALIAS_RID_GUESTS);
}

// Tests whether |token| is a member of the BUILTIN alias identified by
// |group_id| (one of the DOMAIN_ALIAS_RID_* values).
//
// The alias SID is built from SECURITY_NT_AUTHORITY with the two
// sub-authorities SECURITY_BUILTIN_DOMAIN_RID and |group_id|. It is then
// passed with |token| to CheckTokenMembership; a NULL |token| is forwarded
// unchanged.
//
// Returns false if the SID cannot be allocated, if CheckTokenMembership
// fails, or if the token is not a member.
bool UserRights::BelongsToGroup(HANDLE token, int group_id)
{
	SID_IDENTIFIER_AUTHORITY authority = SECURITY_NT_AUTHORITY;
	PSID alias_sid = NULL;
	if (!::AllocateAndInitializeSid(&authority, 2,
			SECURITY_BUILTIN_DOMAIN_RID, group_id,
			0, 0, 0, 0, 0, 0, &alias_sid))
	{
		return false;
	}

	BOOL is_member = FALSE;
	const BOOL query_ok = ::CheckTokenMembership(token, alias_sid, &is_member);
	::FreeSid(alias_sid);

	// A failed membership query is treated the same as "not a member".
	return query_ok && is_member;
}

bool UserRights::UserIsRestricted()
{
	scoped_handle token;
	if (!::OpenProcessToken(::GetCurrentProcess(), TOKEN_QUERY, address(token))) 
	{
		UTIL_LOG(LE, (_T("[UserRights::UserIsRestricted - OpenProcessToken failed]")
			_T("[0x%08x]"), HRESULTFromLastError()));
		return true;
	}

	return !!::IsTokenRestricted(get(token));
}

bool UserRights::UserIsLowOrUntrustedIntegrity() 
{
	if (SystemInfo::IsRunningOnVistaOrLater()) 
	{
		MANDATORY_LEVEL integrity_level = MandatoryLevelUntrusted;
		if (FAILED(vista_util::GetProcessIntegrityLevel(0, &integrity_level)) ||
			integrity_level == MandatoryLevelUntrusted ||
			integrity_level == MandatoryLevelLow) {
				return true;
		}
	}

	return false;
}

// Determines whether the user associated with the current process (as
// reported by NetWkstaUserGetInfo) also owns one of the Terminal Services
// sessions on this machine, i.e. whether that user is logged on.
//
// |is_logged_on| receives true only when a session is found whose domain and
// user names match the workstation user's (case-insensitive comparison).
//
// Returns:
//   - S_OK if a matching session is found.
//   - E_INVALIDARG if |is_logged_on| is NULL.
//   - The failure HRESULT if the workstation user cannot be queried; E_FAIL
//     if the query "succeeds" without returning a buffer.
//   - The failure HRESULT if the sessions cannot be enumerated.
//   - Otherwise (no match): S_OK if every session could be inspected, or the
//     HRESULT of the last per-session failure encountered. A session with an
//     empty domain/user name counts as a failure (E_FAIL).
HRESULT UserRights::UserIsLoggedOnInteractively(bool* is_logged_on)
{
	ASSERT1(is_logged_on);
	// ASSERT1 is presumably debug-only; never dereference NULL in release.
	if (!is_logged_on) {
		return E_INVALIDARG;
	}

	*is_logged_on = false;

	HRESULT hr = S_OK;

	// Get the user associated with the current process.
	WKSTA_USER_INFO_1* user_info = NULL;
	NET_API_STATUS status = ::NetWkstaUserGetInfo(
		NULL,
		1,
		reinterpret_cast<LPBYTE*>(&user_info));
	if (status != NERR_Success || user_info == NULL) {
		UTIL_LOG(LE, (_T("[NetWkstaUserGetInfo failed][%u]"), status));
		// Release any buffer the API handed back despite reporting failure.
		if (user_info != NULL) {
			::NetApiBufferFree(user_info);
		}
		// HRESULT_FROM_WIN32(NERR_Success) is S_OK. A "successful" call that
		// produced no buffer must not be reported as success.
		return status != NERR_Success ? HRESULT_FROM_WIN32(status) : E_FAIL;
	}
	ON_SCOPE_EXIT(::NetApiBufferFree, user_info);

	UTIL_LOG(L2, (_T("[wks domain=%s][wks user=%s]"),
		user_info->wkui1_logon_domain, user_info->wkui1_username));

	// Use the TCHAR-generic type so it agrees with the WTSEnumerateSessions
	// macro and the TCHAR strings used below in both UNICODE and ANSI builds.
	PWTS_SESSION_INFO session_info = NULL;
	const DWORD kVersion = 1;
	DWORD num_sessions = 0;
	if (!::WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE,
		0,
		kVersion,
		&session_info,
		&num_sessions)) {
			hr = HRESULTFromLastError();
			UTIL_LOG(LE, (_T("[WTSEnumerateSessions failed][0x%08x]"), hr));
			return hr;
	}
	ON_SCOPE_EXIT(::WTSFreeMemory, session_info);

	// Loop through all active sessions to see whether one of the sessions
	// belongs to current user. If so, regard this user as "logged-on".
	// A per-session failure is recorded in |hr| but does not stop the scan.
	for (DWORD i = 0; i < num_sessions; ++i) {
		TCHAR* domain_name = NULL;
		DWORD domain_name_len = 0;
		if (!::WTSQuerySessionInformation(WTS_CURRENT_SERVER_HANDLE,
			session_info[i].SessionId,
			WTSDomainName,
			&domain_name,
			&domain_name_len)) {
				hr = HRESULTFromLastError();
				UTIL_LOG(LE, (_T("[WTSQuerySessionInformation failed][0x%08x]"), hr));
				continue;
		}
		ON_SCOPE_EXIT(::WTSFreeMemory, domain_name);

		TCHAR* user_name = NULL;
		DWORD user_name_len = 0;
		if (!::WTSQuerySessionInformation(WTS_CURRENT_SERVER_HANDLE,
			session_info[i].SessionId,
			WTSUserName,
			&user_name,
			&user_name_len)) {
				hr = HRESULTFromLastError();
				UTIL_LOG(LE, (_T("[WTSQuerySessionInformation failed][0x%08x]"), hr));
				continue;
		}
		ON_SCOPE_EXIT(::WTSFreeMemory, user_name);

		UTIL_LOG(L2, (_T("[ts domain=%s][ts user=%s][station=%s]"),
			domain_name,
			user_name,
			session_info[i].pWinStationName));

		// Occasionally, the domain name and user name could not be retrieved when
		// the program is started just at logon time.
		if (!(domain_name && *domain_name && user_name && *user_name)) {
			hr = E_FAIL;
			continue;
		}

		if (_tcsicmp(user_info->wkui1_logon_domain, domain_name) == 0 &&
			_tcsicmp(user_info->wkui1_username, user_name) == 0) {
				*is_logged_on = true;
				return S_OK;
		}
	}

	return hr;
}

// Opens an access token, requested with TOKEN_ALL_ACCESS, that represents the
// caller of the current COM call. At the moment only
// TOKEN_QUERY | TOKEN_ASSIGN_PRIMARY is needed, but requirements may change in
// the future.
//
// - If a COM call context exists (CoGetCallContext succeeds), the COM
//   client's token is opened via CAccessToken::OpenCOMClientToken.
// - If CoGetCallContext returns RPC_E_CALL_COMPLETE (an in-proc,
//   intra-apartment call), the caller is this process itself and
//   CAccessToken::OpenThreadToken is used.
//   NOTE(review): per the ATL docs, OpenThreadToken impersonates self while
//   opening, so the result is presumably an impersonation copy of the process
//   token. That is what CheckTokenMembership (used by TokenIsAdmin) requires,
//   so do not swap this for a primary process token without checking callers.
// - Any other CoGetCallContext failure is logged and returned unchanged.
//
// Returns S_OK on success, otherwise the HRESULT of the failing step.
HRESULT UserRights::GetCallerToken(CAccessToken* token) {
	ASSERT1(token);

	CComPtr<IUnknown> call_context;
	const HRESULT context_hr = ::CoGetCallContext(IID_PPV_ARGS(&call_context));

	bool opened = false;
	if (SUCCEEDED(context_hr)) {
		opened = token->OpenCOMClientToken(TOKEN_ALL_ACCESS);
	} else if (context_hr == RPC_E_CALL_COMPLETE) {
		opened = token->OpenThreadToken(TOKEN_ALL_ACCESS);
	} else {
		UTIL_LOG(LE, (_T("[::CoGetCallContext failed][0x%x]"), context_hr));
		return context_hr;
	}

	// No Win32 call intervenes between the failed open and this read of the
	// last error.
	return opened ? S_OK : HRESULTFromLastError();
}

bool UserRights::VerifyCallerIsAdmin() {
	CAccessToken impersonated_token;
	if (FAILED(GetCallerToken(&impersonated_token))) {
		return false;
	}
	return TokenIsAdmin(impersonated_token.GetHandle());
}

bool UserRights::VerifyCallerIsSystem() {
	CAccessToken impersonated_token;
	if (FAILED(GetCallerToken(&impersonated_token))) {
		return false;
	}

	CSid sid;
	if (!impersonated_token.GetUser(&sid)) {
		return false;
	}

	return sid == Sids::System();
}